import argparse
import datetime
import json
import os
import time
import re
import base64
from io import BytesIO
from PIL import Image, ImageDraw
import gradio as gr
import requests
import hashlib

# Anonymized internal imports
from my_project.conversation import (default_conversation, conv_templates,
                                   SeparatorStyle)
from my_project.constants import LOGDIR
from my_project.utils import (build_logger, server_error_msg,
    violates_moderation, moderation_msg)

# --- Utility Functions ---

def b64_to_pil(b64_str: str):
    """Decode base64 image data into an RGB ``PIL.Image``.

    A ``data:image/...;base64,`` URL prefix is accepted; everything up to the
    last comma is discarded in that case.
    """
    payload = b64_str.split(',')[-1] if b64_str.startswith('data:image') else b64_str
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")

def pil_to_img_tag(img: Image.Image):
    """Encode ``img`` as JPEG and return an inline HTML ``<img>`` tag.

    The image is embedded as a ``data:image/jpeg;base64,...`` URL with
    ``alt="tool_vis"``. Modes JPEG cannot store (e.g. RGBA, LA, P) are
    converted to RGB first; otherwise ``Image.save`` would raise. The caller's
    image is not modified (``convert`` returns a new image).
    """
    if img.mode not in ("1", "L", "RGB", "CMYK"):
        img = img.convert("RGB")
    buf = BytesIO()
    img.save(buf, format="JPEG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/jpeg;base64,{b64}" alt="tool_vis" />'

def draw_boxes(img_b64: str, boxes):
    """Draw red rectangles on a base64-encoded image.

    Args:
        img_b64: base64 image data (a ``data:image`` URL prefix is allowed).
        boxes: iterable of boxes whose first four values are ``x1, y1, x2, y2``;
            any trailing values (e.g. a score) are ignored. A box whose
            largest coordinate exceeds 1.5 is treated as pixel coordinates,
            otherwise as fractions of the image width/height.

    Returns:
        ``(html, image)``: an inline ``<img>`` tag and the annotated PIL image.
    """
    img = b64_to_pil(img_b64)
    W, H = img.size
    draw = ImageDraw.Draw(img)
    for bx in boxes:
        x1, y1, x2, y2 = (float(v) for v in bx[:4])
        if max(x1, y1, x2, y2) <= 1.5:
            # Normalized coordinates -> pixels.
            x1, x2 = x1 * W, x2 * W
            y1, y2 = y1 * H, y2 * H
        # Pillow rejects rectangles whose second corner precedes the first,
        # so order the corners explicitly.
        draw.rectangle(
            [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)],
            outline=(255, 0, 0), width=3)
    html = pil_to_img_tag(img)
    return html, img

def overlay_mask(img_b64: str, masks_data):
    """Overlay a semi-transparent red segmentation mask on a base64 image.

    Args:
        img_b64: base64 image data (a ``data:image`` URL prefix is allowed).
        masks_data: mapping of class name -> COCO RLE (a single RLE dict or a
            list of RLEs). Anything that is not a dict is ignored and the
            image is returned unannotated.

    Returns:
        ``(html, image)``: an inline ``<img>`` tag and the RGB PIL image.
        If ``pycocotools`` is unavailable the plain image is returned.
    """
    import numpy as np
    try:
        from pycocotools import mask as mask_utils
    except ImportError:
        logger.error("pycocotools not available for mask visualization")
        img = b64_to_pil(img_b64)
        return pil_to_img_tag(img), img

    img = b64_to_pil(img_b64).convert("RGBA")
    W, H = img.size

    if isinstance(masks_data, dict):
        for rle in masks_data.values():
            if not rle:
                continue
            mask = mask_utils.decode(rle)
            if mask.ndim == 3:
                # A list of RLEs decodes to (H, W, N); merge the instances so
                # the 2-D boolean indexing below works.
                mask = mask.any(axis=2)
            overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
            overlay[mask > 0] = [255, 0, 0, 120]
            overlay_img = Image.fromarray(overlay, "RGBA")
            if overlay_img.size != (W, H):
                overlay_img = overlay_img.resize((W, H), Image.NEAREST)
            img = Image.alpha_composite(img, overlay_img)
            # NOTE(review): only the first non-empty mask is drawn — confirm
            # this is intended rather than overlaying every class.
            break

    rgb = img.convert("RGB")
    return pil_to_img_tag(rgb), rgb

# --- Tool Call Logic ---

# Per-tool routing used by call_tool: the worker endpoint to POST to and the
# function applied as vis_fn(image_b64, payload) to visualise the result.
TOOL_SCHEMA = dict(
    grounding_dino=dict(endpoint="/worker_generate", vis_fn=draw_boxes),
    lae_dino=dict(endpoint="/worker_generate_stream", vis_fn=draw_boxes),
    remotesam_dino=dict(endpoint="/worker_generate", vis_fn=draw_boxes),
    remote_sam=dict(endpoint="/worker_generate", vis_fn=overlay_mask),
    denodet=dict(endpoint="/worker_generate", vis_fn=draw_boxes),
)

def compress_gdino(output: dict):
    """Return a compact shallow copy of a detection-tool result.

    Box coordinates are rounded to 2 decimals; ``scores`` and ``logits`` are
    cast to float and rounded to 3 decimals. Bulky fields (masks, edited
    image, size, segmentation) are removed. For ``tool_name == 'easyocr'``,
    ``boxes`` and ``scores`` are removed as well. ``output`` is not modified.
    """
    compact = output.copy()
    if 'boxes' in compact:
        compact['boxes'] = [[round(coord, 2) for coord in box] for box in compact['boxes']]
    for numeric_field in ('scores', 'logits'):
        if numeric_field in compact:
            compact[numeric_field] = [round(float(value), 3) for value in compact[numeric_field]]
    for heavy_field in ('masks_rle', 'edited_image', 'size', 'image_seg', 'iou_sort_masks'):
        compact.pop(heavy_field, None)
    if compact.get('tool_name') == 'easyocr':
        for ocr_field in ('boxes', 'scores'):
            compact.pop(ocr_field, None)
    return compact

def call_tool(controller_url: str, tool_name: str, params: dict):
    """Run ``tool_name`` on a worker located via the controller; optionally visualise it.

    Args:
        controller_url: base URL of the controller service.
        tool_name: worker model name; also the key into ``TOOL_SCHEMA``.
        params: JSON payload sent to the worker. If it has a truthy
            ``"image"`` entry (base64), the tool's ``vis_fn`` is applied to it.

    Returns:
        ``(res_json, vis)`` where ``vis`` is whatever the tool's ``vis_fn``
        returned (``(html, PIL image)`` for the functions in this file) or
        ``None``. ``(None, None)`` is returned when no worker is available,
        the controller or worker call fails, or no result was received. A
        result that is not a JSON object is returned as-is with ``vis=None``.
    """
    try:
        # Bounded wait so an unresponsive controller cannot hang the caller.
        addr_ret = requests.post(controller_url + "/get_worker_address",
                                 json={"model": tool_name}, timeout=30)
        worker_addr = addr_ret.json().get("address", "")
    except (requests.exceptions.RequestException, ValueError, AttributeError) as e:
        logger.info(f"Worker lookup for {tool_name} failed: {e}")
        return None, None
    if not worker_addr:
        logger.info(f"No worker available for {tool_name}")
        return None, None

    endpoint = TOOL_SCHEMA.get(tool_name, {}).get("endpoint", "/worker_generate")
    res_json = None
    try:
        if endpoint.endswith('_stream'):
            # Keep the last parsable message; `with` releases the connection.
            with requests.post(worker_addr + endpoint, json=params,
                               timeout=180, stream=True) as rsp:
                for line in rsp.iter_lines():
                    if not line:
                        continue
                    try:
                        payload = json.loads(line.decode('utf-8').strip('\0'))
                    except json.JSONDecodeError:
                        # Skip malformed frames rather than reusing stale text.
                        continue
                    if not isinstance(payload, dict):
                        continue
                    text_field = payload.get("text", "{}")
                    try:
                        res_json = json.loads(text_field)
                    except (json.JSONDecodeError, TypeError):
                        # Plain text, or a value that is already decoded.
                        res_json = text_field
        else:
            res_json = requests.post(worker_addr + endpoint, json=params, timeout=180).json()
    except Exception as e:
        logger.info(f"Tool {tool_name} call error: {e}")
        return None, None

    if res_json is None:
        return None, None
    if not isinstance(res_json, dict):
        # The unwrapping and visualisation below rely on dict methods.
        return res_json, None

    # The actual result may be nested under "text".
    if ("boxes" not in res_json) and isinstance(res_json.get("text"), dict):
        res_json = res_json["text"]

    vis_html = None
    vis_fn = TOOL_SCHEMA.get(tool_name, {}).get("vis_fn")
    if vis_fn and "image" in params and params["image"]:
        try:
            key = "masks" if "masks" in res_json else "boxes"
            vis_html = vis_fn(params["image"], res_json.get(key, {}))
        except Exception as e:
            logger.info(f"vis error: {e}")

    return res_json, vis_html

# --- Gradio Core Logic ---

logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "MyProject Client"}  # Anonymized

# Reusable button updates: leave as-is / make clickable / grey out.
no_change_btn, enable_btn, disable_btn = (
    gr.Button.update(),
    gr.Button.update(interactive=True),
    gr.Button.update(interactive=False),
)

# Sort keys for model names in get_model_list; unlisted models sort by name.
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}

def get_conv_log_filename():
    """Return today's conversation-log path: ``LOGDIR/YYYY-MM-DD-conv.json``."""
    now = datetime.datetime.now()
    date_stamp = "{}-{:02d}-{:02d}".format(now.year, now.month, now.day)
    return os.path.join(LOGDIR, date_stamp + "-conv.json")

def get_model_list():
    """Refresh the controller's workers and return the available model names.

    Names are sorted by their ``priority`` key (unlisted names sort by
    themselves). Any failure is logged and an empty list is returned.
    """
    try:
        ret = requests.post(args.controller_url + "/refresh_all_workers")
        # Explicit check instead of `assert`, which `python -O` strips.
        if ret.status_code != 200:
            raise RuntimeError(
                f"refresh_all_workers returned HTTP {ret.status_code}")
        ret = requests.post(args.controller_url + "/list_models")
        models = ret.json()["models"]
        models.sort(key=lambda x: priority.get(x, x))
        logger.info(f"Models: {models}")
        return models
    except Exception as e:
        logger.error(f"Could not get model list: {e}")
        return []

# Browser-side JavaScript passed as `_js` to demo.load in "once" mode: it
# returns the page's query-string parameters as an object, which becomes the
# `url_params` input of load_demo. `const` keeps url_params from leaking onto
# `window` (and from throwing in strict mode).
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""

def load_demo(url_params, request: gr.Request):
    """Start a session in "once" model-list mode.

    If the URL's ``model`` parameter names a known model (module-level
    ``models``), it is preselected in the dropdown.
    Returns ``(fresh conversation state, dropdown update)``.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    requested_model = url_params["model"] if "model" in url_params else None
    if requested_model is not None and requested_model in models:
        dropdown_update = gr.Dropdown.update(value=requested_model, visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)
    return default_conversation.copy(), dropdown_update

def load_demo_refresh_model_list(request: gr.Request):
    """Start a session in "reload" mode, re-querying the controller for models.

    Returns ``(fresh conversation state, dropdown update)``; the dropdown
    defaults to the first model, or ``""`` when none are available.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    first_choice = available[0] if available else ""
    return (default_conversation.copy(),
            gr.Dropdown.update(choices=available, value=first_choice))

def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record (vote type, model, state, client IP) as a JSON line to today's log."""
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        log_file.write(json.dumps(record) + "\n")

def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, then clear the textbox and grey out the three feedback buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn

def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, then clear the textbox and grey out the three feedback buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn

def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag, then clear the textbox and grey out the three feedback buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn

def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last bot reply so the chained ``http_bot`` produces a new one.

    If the preceding human turn holds an image tuple, its third element (the
    preprocessing mode) is replaced with ``image_process_mode``. With an empty
    history there is nothing to regenerate: ``skip_next`` is set so
    ``http_bot`` leaves the UI untouched, instead of raising IndexError.

    Returns ``(state, chatbot, textbox "", image None, *5 button updates)``.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    if not state.messages:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    state.messages[-1][-1] = None
    # Guard the [-2] lookup: a single-message history has no previous turn.
    if len(state.messages) >= 2 and state.messages[-2][0] == state.roles[0]:
        prev_human_msg = state.messages[-2]
        if isinstance(prev_human_msg[1], (tuple, list)):
            prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

def clear_history(request: gr.Request):
    """Reset to a fresh conversation, clear the inputs and grey out all five buttons."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    greyed_out = tuple(disable_btn for _ in range(5))
    return (fresh, fresh.to_gradio_chatbot(), "", None, *greyed_out)

def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's turn (and an empty bot turn) to the conversation.

    Empty input, or text rejected by moderation when ``--moderate`` is on,
    sets ``skip_next`` so the chained ``http_bot`` does nothing. Text is
    capped at 1536 characters, or 1200 when an image is attached. With an
    image, a ``<image>`` token is appended if missing and the message becomes
    ``(text, image, image_process_mode)``. Attaching an image when the
    conversation already has images starts a fresh conversation.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    unchanged_btns = (no_change_btn,) * 5
    if not text.strip() and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + unchanged_btns
    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None) + unchanged_btns

    if image is None:
        message = text[:1536]
    else:
        clipped = text[:1200]
        if '<image>' not in clipped:
            clipped += '\n<image>'
        message = (clipped, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()

    user_role, bot_role = state.roles[0], state.roles[1]
    state.append_message(user_role, message)
    state.append_message(bot_role, None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

def _select_template_name(model_name):
    """Return the ``conv_templates`` key for ``model_name`` (case-insensitive substring match; default ``"vicuna_v1"``)."""
    name = model_name.lower()
    if "my_project" in name:
        if 'llama-2' in name:
            return "my_project_llama_2"
        if "v1" in name:
            return "my_project_v1"
        return "my_project_v0"
    if "mpt" in name:
        return "mpt"
    if "llama-2" in name:
        return "llama_2"
    return "vicuna_v1"


def _lookup_worker_address(controller_url, model_name):
    """Ask the controller which worker serves ``model_name``; return ``""`` on any failure."""
    try:
        ret = requests.post(controller_url + "/get_worker_address",
                            json={"model": model_name}, timeout=30)
        return ret.json()["address"] or ""
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as e:
        logger.error(f"Worker lookup for {model_name} failed: {e}")
        return ""


def _log_request_images(images):
    """Save each PIL image to ``LOGDIR/serve_images/<YYYY-MM-DD>/<md5>.jpg`` if absent; return the md5 digests in order."""
    hashes = [hashlib.md5(image.tobytes()).hexdigest() for image in images]
    for image, hash_val in zip(images, hashes):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash_val}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)
    return hashes


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the selected model's reply into the chat, yielding UI updates.

    Every yield is ``(state, chatbot_messages, *five_button_updates)``. On
    success the finished exchange is appended as a JSON line to today's
    conversation log. Controller, network, and malformed-stream failures put
    ``server_error_msg`` in the reply slot and re-enable regenerate/clear
    instead of raising out of the generator.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    # Feedback buttons stay disabled; regenerate/clear are enabled to retry.
    error_btns = (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    if state.skip_next:
        # The preceding handler rejected the input; leave the UI unchanged.
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First exchange: rebuild on the template matching the chosen model.
        new_state = conv_templates[_select_template_name(model_name)].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    worker_addr = _lookup_worker_address(args.controller_url, model_name)
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    prompt = state.get_prompt()
    # Replace inline base64 <img> tags (tool visualisations) with a placeholder.
    prompt = re.sub(r"<img src=\"data:image[^>]+>", "<tool_vis>", prompt)

    all_image_hash = _log_request_images(state.get_images(return_pil=True))

    pload = {
        "model": model_name, "prompt": prompt, "temperature": float(temperature),
        "top_p": float(top_p), "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": state.get_images()
    }
    logger.info(f"==== request ====\n{pload}")

    # Cursor placeholder while the reply streams in.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        with requests.post(worker_addr + "/worker_generate_stream",
                           headers=headers, json=pload, stream=True, timeout=10) as response:
            # The worker sends NUL-delimited JSON chunks; each carries the full
            # text so far, which is prefixed by the prompt.
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                error_code = data.get("error_code", 0)
                if error_code == 0:
                    state.messages[-1][-1] = data["text"][len(prompt):].strip() + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    state.messages[-1][-1] = data.get("text", "Unknown error")
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                    yield (state, state.to_gradio_chatbot()) + error_btns
                    return
    except (requests.exceptions.RequestException, ValueError, KeyError) as e:
        # ValueError covers malformed JSON / undecodable bytes; KeyError a missing "text".
        logger.error(f"Streaming from {worker_addr} failed: {e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    state.messages[-1][-1] = state.messages[-1][-1].rstrip("▌")
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    # Tool pipeline logic would continue here, but is omitted for this example as it is complex
    # and assumed to be correctly implemented without identity leaks.

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        record = {
            "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name,
            "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4),
            "state": state.dict(), "images": all_image_hash, "ip": request.client.host,
        }
        fout.write(json.dumps(record) + "\n")

# --- UI Layout and Text ---

# Markdown shown above the chat UI (omitted in embed mode).
title_markdown = "\n".join([
    "",
    "# Model Demo",
    "This is a research demonstration of a vision-language model.",
    "",
])

# Terms-of-use markdown shown below the chat UI (omitted in embed mode).
tos_markdown = "\n".join([
    "",
    "### Terms of Use",
    "By using this service, users are required to agree to the following terms:",
    "The service is a research preview intended for non-commercial use only. It must not be used for any illegal, harmful, or offensive purposes. The service may collect user dialogue data for future research.",
    "",
])

# License markdown shown below the terms of use (omitted in embed mode).
learn_more_markdown = "\n".join([
    "",
    "### License",
    "The service is a research preview. It is subject to the licenses of the models and data used in its training.",
    "",
])

# Extra CSS for the gr.Blocks app: minimum width for the feedback buttons row.
block_css = "\n".join([
    "",
    "#buttons button { min-width: min(120px,100%); }",
    "",
])

def build_demo(embed_mode):
    """Build the Gradio Blocks UI and wire its event handlers.

    Args:
        embed_mode: when truthy, the title, terms-of-use and license markdown
            are not rendered.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).

    Reads the module-level ``models`` (dropdown choices) and ``args``
    (``model_list_mode``), both assigned in the ``__main__`` section.

    Raises:
        ValueError: if ``args.model_list_mode`` is neither "once" nor "reload".
    """
    # Created outside the Blocks context and placed later via textbox.render();
    # presumably so gr.Examples (declared before the chat column) can use it.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="Model Demo", theme=gr.themes.Default(), css=block_css) as demo: # Anonymized
        # Per-session conversation; populated by the demo.load handler below.
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, image input, examples, sampling params.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models, value=models[0] if models else "",
                        interactive=True, show_label=False, container=False)

                imagebox = gr.Image(type="pil")
                # Hidden, so the value stays "Default" unless changed elsewhere.
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"], value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/11765.jpg", "What are the types of airplanes present?"],
                    [f"{cur_dir}/examples/11760.jpg", "What are your thoughts on urban planning in this region?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(0.0, 1.0, 0.2, step=0.1, interactive=True, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, 0.7, step=0.1, interactive=True, label="Top P")
                    max_output_tokens = gr.Slider(0, 1024, 512, step=64, interactive=True, label="Max output tokens")

            # Right column: chat history, input row and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="Model Demo Bot", height=550) # Anonymized
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button("Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    upvote_btn = gr.Button("👍  Upvote", interactive=False)
                    downvote_btn = gr.Button("👎  Downvote", interactive=False)
                    flag_btn = gr.Button("⚠️  Flag", interactive=False)
                    regenerate_btn = gr.Button("🔄  Regenerate", interactive=False)
                    clear_btn = gr.Button("🗑️  Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden sink for the URL query parameters read by get_window_url_params.
        url_params = gr.JSON(visible=False)

        # Order matters: handlers return button updates in this order.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        # Feedback buttons clear the textbox and update only the first three buttons.
        upvote_btn.click(upvote_last_response, [state, model_selector], [textbox] + btn_list[:3])
        downvote_btn.click(downvote_last_response, [state, model_selector], [textbox] + btn_list[:3])
        flag_btn.click(flag_last_response, [state, model_selector], [textbox] + btn_list[:3])
        # Regenerate/submit first update the state, then stream the reply via http_bot.
        regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list).then(
            http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

        # ENTER in the textbox and the Send button run the same pipeline.
        textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list).then(
            http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
        submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list).then(
            http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)

        # "once": reuse the startup model list and preselect a model from URL
        # params; "reload": re-query the controller on every page load.
        if args.model_list_mode == "once":
            demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params)
        elif args.model_list_mode == "reload":
            demo.load(load_demo_refresh_model_list, None, [state, model_selector])
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo

# Captures the reasoning, actions and value sections (in that order); DOTALL
# lets each section span multiple lines.
_TOOL_OUTPUT_PATTERN = re.compile(
    r'"reasoning💭"\s*(.*?)\s*"actions⚡"\s*(.*?)\s*"value✨"\s*(.*)',
    re.DOTALL,
)


def parse_tool_output(text: str):
    """Extract ``(reasoning, actions, value)`` tuples from model output.

    Returns the list of all matches, or ``None`` when there are none.
    """
    matches = _TOOL_OUTPUT_PATTERN.findall(text)
    return matches or None

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    cli_options = [
        ("--host", dict(type=str, default="0.0.0.0")),
        ("--port", dict(type=int)),
        ("--controller-url", dict(type=str, default="http://localhost:21001")),
        ("--concurrency-count", dict(type=int, default=10)),
        ("--model-list-mode", dict(type=str, default="once", choices=["once", "reload"])),
        ("--share", dict(action="store_true")),
        ("--moderate", dict(action="store_true")),
        ("--embed", dict(action="store_true")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    # Module-level globals read by the handlers above (args, models).
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    demo = build_demo(args.embed)
    queued_demo = demo.queue(concurrency_count=args.concurrency_count, api_open=False)
    queued_demo.launch(server_name=args.host, server_port=args.port, share=args.share)